diff --git a/megatron/arguments.py b/megatron/arguments.py
index 26a7cec..0acfb22 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -21,6 +21,8 @@ import os
 import torch
 from megatron import fused_kernels

+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
@@ -40,6 +42,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
+    parser = _add_fmoe_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 9d42260..2583db2 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                                                                   param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
+                        if hasattr(param, 'dp_comm'):
+                            main_param.dp_comm = param.dp_comm
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
                         fp32_from_fp16_params_this_group.append(main_param)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..f825bf3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,20 +35,24 @@ from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+# FastMoE: load_checkpoint is taken from fmoe.megatron.checkpoint instead of megatron.checkpointing.
+from fmoe.megatron.checkpoint import load_checkpoint
+# FastMoE: save_checkpoint is taken from fmoe.megatron.checkpoint instead of megatron.checkpointing.
+from fmoe.megatron.checkpoint import save_checkpoint
 from megatron.model import FP16Module
 from megatron.optimizer import get_megatron_optimizer

 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# FastMoE: LocalDDP (DistributedDataParallel) is imported from fmoe.megatron below instead of megatron.model.
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory

+from fmoe.megatron import DistributedDataParallel as LocalDDP
+from fmoe.megatron import add_balance_log

 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()

+    # Initialize FastMoE: wrap the forward step and model provider when args.fmoefy is set.
+    if args.fmoefy:
+        from fmoe.megatron import patch_forward_step, patch_model_provider
+
+        forward_step_func = patch_forward_step(forward_step_func)
+        model_provider = patch_model_provider(model_provider)
+
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator,


 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, skipped_iter):
+                 loss_scale, report_memory_flag, skipped_iter, model):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()
@@ -725,6 +736,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                           args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)
+    if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive':
+        add_balance_log(model, writer, iteration)

     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval time').elapsed()
@@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
-                                          report_memory_flag, skipped_iter)
+                                          report_memory_flag, skipped_iter, model)

         # Autoresume
         if args.adlr_autoresume and \
